{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Text classification with Transformer\n",
    "\n",
    "**Author:** [Apoorv Nandan](https://twitter.com/NandanApoorv)<br>\n",
    "**Date created:** 2020/05/10<br>\n",
    "**Last modified:** 2020/05/10<br>\n",
    "**Description:** Implement a Transformer block as a Keras layer and use it for text classification."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Implement multi-head self-attention as a Keras layer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class MultiHeadSelfAttention(layers.Layer):\n",
    "    \"\"\"Multi-head scaled dot-product self-attention.\n",
    "\n",
    "    The input is projected to queries, keys and values by three Dense layers,\n",
    "    split into `num_heads` heads of width `embed_dim // num_heads`, attended\n",
    "    per head, concatenated again and passed through a final Dense layer.\n",
    "\n",
    "    Args:\n",
    "        embed_dim: Width of the query/key/value projections and of the output.\n",
    "        num_heads: Number of attention heads; must divide `embed_dim`.\n",
    "        **kwargs: Forwarded to `layers.Layer` (e.g. `name`, `dtype`).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `embed_dim` is not divisible by `num_heads`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embed_dim, num_heads=8, **kwargs):\n",
    "        super(MultiHeadSelfAttention, self).__init__(**kwargs)\n",
    "        self.embed_dim = embed_dim\n",
    "        self.num_heads = num_heads\n",
    "        if embed_dim % num_heads != 0:\n",
    "            raise ValueError(\n",
    "                f\"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}\"\n",
    "            )\n",
    "        # Each head works on an equal slice of the embedding.\n",
    "        self.projection_dim = embed_dim // num_heads\n",
    "        self.query_dense = layers.Dense(embed_dim)\n",
    "        self.key_dense = layers.Dense(embed_dim)\n",
    "        self.value_dense = layers.Dense(embed_dim)\n",
    "        self.combine_heads = layers.Dense(embed_dim)\n",
    "\n",
    "    def attention(self, query, key, value):\n",
    "        \"\"\"Scaled dot-product attention.\n",
    "\n",
    "        Returns:\n",
    "            A tuple `(output, weights)`: the attention-weighted sum of `value`\n",
    "            and the softmax-normalized attention weights.\n",
    "        \"\"\"\n",
    "        score = tf.matmul(query, key, transpose_b=True)\n",
    "        # Cast to the dtype of `score` instead of a hard-coded float32 so the\n",
    "        # division below also works under mixed-precision or float64 policies.\n",
    "        dim_key = tf.cast(tf.shape(key)[-1], score.dtype)\n",
    "        scaled_score = score / tf.math.sqrt(dim_key)\n",
    "        weights = tf.nn.softmax(scaled_score, axis=-1)\n",
    "        output = tf.matmul(weights, value)\n",
    "        return output, weights\n",
    "\n",
    "    def separate_heads(self, x, batch_size):\n",
    "        \"\"\"Reshape `x` to (batch_size, num_heads, seq_len, projection_dim).\"\"\"\n",
    "        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))\n",
    "        return tf.transpose(x, perm=[0, 2, 1, 3])\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Apply self-attention to `inputs` and return the combined heads.\"\"\"\n",
    "        # inputs.shape = [batch_size, seq_len, embedding_dim]\n",
    "        batch_size = tf.shape(inputs)[0]\n",
    "        query = self.query_dense(inputs)  # (batch_size, seq_len, embed_dim)\n",
    "        key = self.key_dense(inputs)  # (batch_size, seq_len, embed_dim)\n",
    "        value = self.value_dense(inputs)  # (batch_size, seq_len, embed_dim)\n",
    "        query = self.separate_heads(\n",
    "            query, batch_size\n",
    "        )  # (batch_size, num_heads, seq_len, projection_dim)\n",
    "        key = self.separate_heads(\n",
    "            key, batch_size\n",
    "        )  # (batch_size, num_heads, seq_len, projection_dim)\n",
    "        value = self.separate_heads(\n",
    "            value, batch_size\n",
    "        )  # (batch_size, num_heads, seq_len, projection_dim)\n",
    "        # The attention weights are not needed by this layer's output.\n",
    "        attention, _ = self.attention(query, key, value)\n",
    "        attention = tf.transpose(\n",
    "            attention, perm=[0, 2, 1, 3]\n",
    "        )  # (batch_size, seq_len, num_heads, projection_dim)\n",
    "        concat_attention = tf.reshape(\n",
    "            attention, (batch_size, -1, self.embed_dim)\n",
    "        )  # (batch_size, seq_len, embed_dim)\n",
    "        output = self.combine_heads(\n",
    "            concat_attention\n",
    "        )  # (batch_size, seq_len, embed_dim)\n",
    "        return output\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the constructor arguments so the layer can be re-created.\"\"\"\n",
    "        config = super(MultiHeadSelfAttention, self).get_config()\n",
    "        config.update({\"embed_dim\": self.embed_dim, \"num_heads\": self.num_heads})\n",
    "        return config\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Implement a Transformer block as a layer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class TransformerBlock(layers.Layer):\n",
    "    \"\"\"Transformer block: self-attention followed by a feed-forward network.\n",
    "\n",
    "    The output of each of the two sub-layers goes through dropout, is added\n",
    "    to that sub-layer's input and the sum is layer-normalized.\n",
    "\n",
    "    Args:\n",
    "        embed_dim: Width of the attention output and of the block output.\n",
    "        num_heads: Number of attention heads.\n",
    "        ff_dim: Units in the hidden ReLU layer of the feed-forward network.\n",
    "        rate: Dropout rate used after each sub-layer.\n",
    "        **kwargs: Forwarded to `layers.Layer` (e.g. `name`, `dtype`).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):\n",
    "        super(TransformerBlock, self).__init__(**kwargs)\n",
    "        # Kept only so that get_config() can report them.\n",
    "        self.embed_dim = embed_dim\n",
    "        self.num_heads = num_heads\n",
    "        self.ff_dim = ff_dim\n",
    "        self.rate = rate\n",
    "        self.att = MultiHeadSelfAttention(embed_dim, num_heads)\n",
    "        self.ffn = keras.Sequential(\n",
    "            [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim)]\n",
    "        )\n",
    "        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.dropout1 = layers.Dropout(rate)\n",
    "        self.dropout2 = layers.Dropout(rate)\n",
    "\n",
    "    def call(self, inputs, training=None):\n",
    "        \"\"\"Run the block on `inputs`.\n",
    "\n",
    "        Args:\n",
    "            inputs: Tensor that is added to the attention output, so its last\n",
    "                axis must have width `embed_dim`.\n",
    "            training: Passed to both dropout layers. Defaults to None so the\n",
    "                layer can also be called without the argument.\n",
    "        \"\"\"\n",
    "        attn_output = self.att(inputs)\n",
    "        attn_output = self.dropout1(attn_output, training=training)\n",
    "        out1 = self.layernorm1(inputs + attn_output)  # residual + norm\n",
    "        ffn_output = self.ffn(out1)\n",
    "        ffn_output = self.dropout2(ffn_output, training=training)\n",
    "        return self.layernorm2(out1 + ffn_output)  # residual + norm\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the constructor arguments so the layer can be re-created.\"\"\"\n",
    "        config = super(TransformerBlock, self).get_config()\n",
    "        config.update(\n",
    "            {\n",
    "                \"embed_dim\": self.embed_dim,\n",
    "                \"num_heads\": self.num_heads,\n",
    "                \"ff_dim\": self.ff_dim,\n",
    "                \"rate\": self.rate,\n",
    "            }\n",
    "        )\n",
    "        return config\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Implement embedding layer\n",
    "\n",
    "Two separate embedding layers: one for the tokens and one for the token positions (indices).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class TokenAndPositionEmbedding(layers.Layer):\n",
    "    \"\"\"Sum of a learned token embedding and a learned position embedding.\n",
    "\n",
    "    Args:\n",
    "        maxlen: Number of rows of the position-embedding table.\n",
    "        vocab_size: Number of rows of the token-embedding table.\n",
    "        emded_dim: Width of both embeddings. The (misspelled) parameter name\n",
    "            is kept so that existing keyword callers keep working.\n",
    "        **kwargs: Forwarded to `layers.Layer` (e.g. `name`, `dtype`).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, maxlen, vocab_size, emded_dim, **kwargs):\n",
    "        super(TokenAndPositionEmbedding, self).__init__(**kwargs)\n",
    "        # Kept only so that get_config() can report them.\n",
    "        self.maxlen = maxlen\n",
    "        self.vocab_size = vocab_size\n",
    "        self.emded_dim = emded_dim\n",
    "        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=emded_dim)\n",
    "        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=emded_dim)\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Embed the token ids in `x` and add the embedding of each position.\"\"\"\n",
    "        # Length of the last axis of the actual input, not the constructor's maxlen.\n",
    "        seq_len = tf.shape(x)[-1]\n",
    "        positions = tf.range(start=0, limit=seq_len, delta=1)\n",
    "        position_vectors = self.pos_emb(positions)\n",
    "        token_vectors = self.token_emb(x)\n",
    "        # The position vectors carry no batch axis and are broadcast over it.\n",
    "        return token_vectors + position_vectors\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the constructor arguments so the layer can be re-created.\"\"\"\n",
    "        config = super(TokenAndPositionEmbedding, self).get_config()\n",
    "        config.update(\n",
    "            {\n",
    "                \"maxlen\": self.maxlen,\n",
    "                \"vocab_size\": self.vocab_size,\n",
    "                \"emded_dim\": self.emded_dim,\n",
    "            }\n",
    "        )\n",
    "        return config\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Download and prepare dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "vocab_size = 20000  # Only consider the top 20k words\n",
    "maxlen = 200  # Every review is padded / truncated to exactly 200 words\n",
    "(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)\n",
    "print(len(x_train), \"Training sequences\")\n",
    "print(len(x_val), \"Validation sequences\")\n",
    "# pad_sequences pads and truncates at the FRONT by default (spelled out below),\n",
    "# so reviews longer than maxlen keep their LAST maxlen words and shorter ones\n",
    "# are left-padded with 0.\n",
    "x_train = keras.preprocessing.sequence.pad_sequences(\n",
    "    x_train, maxlen=maxlen, padding=\"pre\", truncating=\"pre\"\n",
    ")\n",
    "x_val = keras.preprocessing.sequence.pad_sequences(\n",
    "    x_val, maxlen=maxlen, padding=\"pre\", truncating=\"pre\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Create classifier model using transformer layer\n",
    "\n",
    "The Transformer layer outputs one vector for each time step of the input sequence.\n",
    "Here, we take the mean across all time steps and\n",
    "use a feed-forward network on top of it to classify the text.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Hyperparameters of the single Transformer block used by the classifier.\n",
    "embed_dim = 32  # Width of each token / position embedding vector\n",
    "num_heads = 2  # Attention heads (embed_dim must be divisible by this)\n",
    "ff_dim = 32  # Hidden units of the feed-forward network inside the block\n",
    "\n",
    "# Token ids -> token + position embeddings -> Transformer block.\n",
    "inputs = keras.Input(shape=(maxlen,))\n",
    "embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)\n",
    "x = embedding_layer(inputs)\n",
    "transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)\n",
    "x = transformer_block(x)\n",
    "\n",
    "# Average over the time axis, then a small dense head with dropout around it.\n",
    "x = layers.Dropout(0.1)(layers.GlobalAveragePooling1D()(x))\n",
    "x = layers.Dropout(0.1)(layers.Dense(20, activation=\"relu\")(x))\n",
    "outputs = layers.Dense(2, activation=\"softmax\")(x)\n",
    "\n",
    "model = keras.Model(inputs=inputs, outputs=outputs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train and Evaluate\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# The sparse variant of the loss takes integer class ids rather than one-hot vectors.\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"sparse_categorical_crossentropy\",\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "history = model.fit(\n",
    "    x=x_train,\n",
    "    y=y_train,\n",
    "    batch_size=32,\n",
    "    epochs=2,\n",
    "    validation_data=(x_val, y_val),\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "text_classification_with_transformer",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}